import torch
import os

import json
from .ops import *
from .bert import *
from .config import ModelConfig
from .cache_utils import load_model_state

# Public names exported by a star-import of this module.
__all__ = ['DeBERTa']

class DeBERTa(torch.nn.Module):
  """Encoder module composed of `BertEmbeddings` followed by `BertEncoder`.

  Args:
    config: Model configuration passed to `BertEmbeddings` and `BertEncoder`.
      Its optional `z_steps` attribute (default 0) controls how many times
      the last encoder layer is applied in `forward`.
    pre_trained: Optional identifier handed to `load_model_state`. When given,
      the configuration returned by `load_model_state` replaces `config` and
      the returned state is copied into this module.
  """

  def __init__(self, config=None, pre_trained=None):
    super().__init__()
    # NOTE(review): z_steps is taken from the caller-supplied config only; a
    # config loaded from `pre_trained` below does not influence it. Kept as-is
    # to preserve existing behaviour -- confirm whether this is intended.
    self.z_steps = getattr(config, 'z_steps', 0) if config else 0

    state = None
    if pre_trained is not None:
      state, config = load_model_state(pre_trained)
    self.embeddings = BertEmbeddings(config)
    self.encoder = BertEncoder(config)
    self.config = config
    self.pre_trained = pre_trained
    self.apply_state(state)

  def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, position_ids = None, return_att = False):
    """Run the embeddings and the encoder over `input_ids`.

    Args:
      input_ids: Tensor of token ids; cast to `torch.long` before embedding.
      token_type_ids: Optional tensor; defaults to zeros shaped like `input_ids`.
      attention_mask: Optional tensor; defaults to ones shaped like `input_ids`.
      output_all_encoded_layers: Forwarded to the encoder. When false, only the
        last element of the layer list is returned (still wrapped in a list).
      position_ids: Optional value forwarded to the embeddings.
      return_att: Forwarded to the encoder; when true the encoder result is
        unpacked as `(encoded_layers, att_matrixs)` and both are returned.

    Returns:
      The list of encoded layers, or `(encoded_layers, att_matrixs)` when
      `return_att` is true.
    """
    if attention_mask is None:
      attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
      token_type_ids = torch.zeros_like(input_ids)

    embedding_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask)
    encoded_layers = self.encoder(embedding_output,
                   attention_mask,
                   output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
    if return_att:
      encoded_layers, att_matrixs = encoded_layers

    if self.z_steps > 1:
      # Re-apply the last encoder layer (z_steps - 1) more times: the output of
      # the second-to-last layer stays fixed as `hidden_states` while the
      # running result is fed back in as `query_states`.
      # NOTE(review): this indexes encoded_layers[-2], so the encoder must have
      # returned at least two layers -- confirm for
      # output_all_encoded_layers=False.
      hidden_states = encoded_layers[-2]
      last_layer = self.encoder.layer[-1]
      query_states = encoded_layers[-1]
      rel_embeddings = self.encoder.get_rel_embedding()
      attention_mask = self.encoder.get_attention_mask(attention_mask)
      rel_pos = self.encoder.get_rel_pos(embedding_output)
      # Fixed: the original referenced a bare `z_steps`, which is undefined in
      # this scope and raised NameError whenever this branch was taken.
      for _ in range(self.z_steps - 1):
        query_states = last_layer(hidden_states, attention_mask, return_att=False, query_states = query_states, relative_pos=rel_pos, rel_embeddings=rel_embeddings)
        encoded_layers.append(query_states)

    if not output_all_encoded_layers:
      encoded_layers = encoded_layers[-1:]

    if return_att:
      return encoded_layers, att_matrixs
    return encoded_layers

  def apply_state(self, state = None):
    """Copy weights from `state` into this module.

    Does nothing when neither `state` nor `self.pre_trained` is set. When
    `state` is None but `self.pre_trained` is set, the state and config are
    loaded via `load_model_state(self.pre_trained)` and `self.config` is
    replaced by the loaded config.

    Each key of this module's `state_dict()` is resolved to the single key in
    `state` that contains it as a substring.

    Raises:
      AssertionError: If a parameter name matches zero or more than one key in
        `state`; the exception argument is the list of candidate keys.
    """
    if self.pre_trained is None and state is None:
      return
    if state is None:
      # Fixed: the original referenced a bare `pre_trained`, which is undefined
      # in this scope and raised NameError.
      state, config = load_model_state(self.pre_trained)
      self.config = config

    def key_match(key, candidates):
      matches = [k for k in candidates if key in k]
      # Raised explicitly (rather than via `assert`) so the check still runs
      # under `python -O`; the exception type and payload are unchanged.
      if len(matches) != 1:
        raise AssertionError(matches)
      return matches[0]

    state_keys = list(state.keys())
    current = self.state_dict()
    for name in current:
      current[name] = state[key_match(name, state_keys)]
    self.load_state_dict(current)
